import logging
from datetime import datetime

from flask import abort, jsonify
from flask_jwt_extended import (create_access_token, create_refresh_token, jwt_required, get_jwt_identity, get_jwt)
from flask_restful import Resource, reqparse, fields, marshal, marshal_with

from app.models.users import User, RevokedTokenModel
from ext import db, jwt
from app.utils.code_msg import ResponseCode
from app.utils.response import ResMsg

"""
# In Flask-JWT-Extended 4.x the following callback functions were all changed
# to take two arguments: the jwt_headers and the jwt_payload.
- @jwt.needs_fresh_token_loader
- @jwt.revoked_token_loader
- @jwt.user_lookup_loader
- @jwt.user_lookup_error_loader
- @jwt.expired_token_loader
- @jwt.token_in_blocklist_loader
- @jwt.token_verification_loader
- @jwt.token_verification_failed_loader

Reference: https://flask-jwt-extended.readthedocs.io/en/stable/v4_upgrade_guide/#callback-function-changes
"""


@jwt.expired_token_loader
def expired_token_callback(jwt_headers, jwt_payload):
    """Build the 401 response for a token that is valid but has expired.

    Registered through ``jwt.expired_token_loader``. Neither argument is
    used by this handler.

    Returns:
        A ``(response, status)`` tuple: a JSON body with ``code`` and ``err``
        keys, and HTTP status 401.
    """
    body = jsonify(code="401", err="token 已过期")
    return body, 401


@jwt.unauthorized_loader
def my_unauthorized_token_callback(jwt_headers):
    """Build the 401 response sent when no Authorization header was supplied.

    NOTE(review): Flask-JWT-Extended appears to pass an ``unauthorized_loader``
    callback a string explaining why no JWT was found, not the headers the
    parameter name suggests -- confirm. The argument is unused here.

    Returns:
        A ``(response, status)`` tuple: a JSON body with an ``err`` key, and
        HTTP status 401.
    """
    response = jsonify(err="Missing Authorization Header")
    return response, 401


@jwt.invalid_token_loader
def invalid_token_callback(error):  # the caller passes this argument internally, so it must stay in the signature
    """Build the 401 response for an invalid token.

    Returns:
        A ``(response, status)`` tuple: a JSON body with ``message`` and
        ``error`` keys, and HTTP status 401.
    """
    payload = dict(message='Signature verification failed.', error='invalid_token')
    return jsonify(payload), 401


# Output schemas: these describe the JSON shape produced when marshalling.

# Shape of a single serialized user; both timestamps are rendered as ISO 8601.
users_fields = dict(
    id=fields.Integer,
    email=fields.String,
    username=fields.String,
    add_time=fields.DateTime(dt_format='iso8601'),
    login_time=fields.DateTime(dt_format='iso8601'),
    is_active=fields.Boolean,
    uri=fields.Url(absolute=True),
)

# Envelope pairing a count with a list of users in the shape above.
users_list_fields = dict(
    count=fields.Integer,
    users=fields.List(fields.Nested(users_fields)),
)


class Register(Resource):
    """Resource that creates a new user account from a JSON request body."""

    def __init__(self):
        # email, username and password are all mandatory JSON body fields.
        parser = reqparse.RequestParser()
        parser.add_argument('email', type=str, required=True, location='json', help='email不能为空')
        parser.add_argument('username', type=str, required=True, location='json', help="用户名不能为空")
        parser.add_argument('password', type=str, required=True, location='json', help="密码不能为空")
        self.reqparse = parser
        super(Register, self).__init__()
        self.res = ResMsg()

    def post(self):
        """Register a user unless the username is already taken.

        Returns:
            A JSON response built from ``self.res.data``: code
            ``ResponseCode.Success`` after the user is stored, or
            ``ResponseCode.RequestToRepeat`` when the username already exists.
        """
        args = self.reqparse.parse_args()
        if any(key not in args for key in ('username', 'email', 'password')):
            abort(400)

        username = args['username']
        if User.query.filter_by(username=username).first():
            self.res.update(code=ResponseCode.RequestToRepeat, msg="用户已经存在")
            return jsonify(self.res.data)

        # The stored password is whatever User.set_password returns
        # (presumably the encrypted form -- verify against the model).
        new_user = User(
            email=args['email'],
            username=username,
            password=User.set_password(User, args['password']),
        )
        db.session.add(new_user)
        db.session.commit()
        self.res.update(code=ResponseCode.Success)
        return jsonify(self.res.data)


class Login(Resource):
    """Resource that authenticates a user and issues JWT access/refresh tokens."""

    def __init__(self):
        # username and password are both mandatory JSON body fields.
        self.reqparse = reqparse.RequestParser()
        self.reqparse.add_argument('username', type=str, required=True, location='json', help="用户名不能为空")
        self.reqparse.add_argument('password', type=str, required=True, location='json', help="密码不能为空")
        super(Login, self).__init__()
        self.res = ResMsg()

    def _fail(self, msg):
        """Return the JSON failure response carrying ``msg``."""
        self.res.update(code=ResponseCode.Fail, msg=msg)
        return jsonify(self.res.data)

    def post(self):
        """Check the credentials and, on success, return a fresh token pair.

        Returns:
            A JSON response built from ``self.res.data``. On success its
            ``data`` holds ``access_token`` and ``refresh_token``; otherwise
            the code is ``ResponseCode.Fail`` with a message naming the
            reason (empty credentials, unknown user, or wrong password).
        """
        args = self.reqparse.parse_args(strict=True)
        username = args['username']
        password = args['password']
        if not username or not password:
            return self._fail("用户名和密码不能为空")

        user = User.query.filter_by(username=username).first()
        if user is None:
            return self._fail("找不到用户")
        if not User.check_password(User, user.password, password):
            return self._fail("密码不正确")

        # Store the datetime object itself rather than its string form:
        # login_time is marshalled with fields.DateTime in this file, which
        # formats a datetime. NOTE(review): assumes the column is a DateTime
        # -- confirm against the model.
        user.login_time = datetime.now()
        # Pass the instance (as Userinfo.put does), not the User class.
        User.update(user)

        access_token = create_access_token(identity=username, fresh=True)
        refresh_token = create_refresh_token(identity=username)
        # get_jwt_identity() can later recover this username from the token.
        result_msg = dict(access_token=access_token, refresh_token=refresh_token)
        self.res.update(code=ResponseCode.Success, data=result_msg, msg="登录成功")
        return jsonify(self.res.data)


class UserList(Resource):
    """Resource listing every user."""

    # A valid token is required.
    @jwt_required()
    def get(self):
        """Return all users, marshalled with ``users_fields``, and their count.

        Returns:
            A dict with ``count`` (number of users) and ``data`` (the list of
            marshalled users).
        """
        everyone = User.query.all()
        payload = {"count": len(everyone)}
        payload["data"] = [marshal(member, users_fields) for member in everyone]
        return payload


class Userinfo(Resource):
    """Resource for reading, updating and deleting a single user by id."""

    def __init__(self):
        # Every field is optional so that PUT can update any subset of them.
        # NOTE(review): type=bool applies bool() to the raw JSON value, so a
        # JSON *string* such as "false" would parse as True -- confirm that
        # clients always send a real JSON boolean.
        self.reqparse = reqparse.RequestParser()
        self.reqparse.add_argument('username', type=str, location='json')
        self.reqparse.add_argument('password', type=str, location='json')
        self.reqparse.add_argument('email', type=str, location='json')
        self.reqparse.add_argument('is_active', type=bool, location='json')
        super(Userinfo, self).__init__()
        self.res = ResMsg()

    # View a single user; a valid token is required.
    @marshal_with(users_fields)
    @jwt_required()
    def get(self, id):
        """Return the user with the given id, marshalled with ``users_fields``.

        Uses ``get_or_404``, so an unknown id results in a 404.
        """
        return User.query.get_or_404(id)

    # Update a user; a valid token is required.
    @marshal_with(users_fields)
    @jwt_required()
    def put(self, id):
        """Update the supplied fields of a user and return the updated user.

        ``username``, ``password`` and ``email`` are only applied when they
        are non-empty; ``is_active`` is applied whenever it is present in the
        request, including when it is false.

        Aborts with 400 when ``id`` is falsy; ``get_or_404`` yields a 404 for
        an unknown id.
        """
        if not id:
            abort(400)
        user = User.query.get_or_404(id)

        args = self.reqparse.parse_args(strict=True)

        if args["username"]:
            user.username = args["username"]
        if args["password"]:
            # The new password goes through User.set_password before storage.
            user.password = User.set_password(User, args["password"])
        if args["email"]:
            user.email = args["email"]
        # Test against None, not truthiness: a truthiness test silently
        # dropped an explicit false, so a user could never be deactivated.
        if args["is_active"] is not None:
            user.is_active = args["is_active"]

        User.update(user)
        return user

    # Delete a user; a valid token is required.
    @jwt_required()
    def delete(self, id):
        """Delete the user with the given id and return a success message.

        Aborts with 400 when ``id`` is falsy; ``get_or_404`` yields a 404 for
        an unknown id.
        """
        if not id:
            abort(400)
        user = User.query.get_or_404(id)
        db.session.delete(user)
        db.session.commit()
        self.res.update(code=ResponseCode.Success, msg="用户删除成功")
        return jsonify(self.res.data)


class TokenRefresh(Resource):
    """
    Extra endpoint for refreshing the access token.
    Reference: https://www.jianshu.com/p/c155c2b7af42
    """

    @jwt_required(refresh=True)
    def post(self):
        """Issue a new access token for the identity of the current JWT.

        Returns:
            A dict with a single ``access_token`` key.
        """
        identity = get_jwt_identity()
        return {'access_token': create_access_token(identity=identity)}


@jwt.token_in_blocklist_loader
def check_if_token_in_blacklist(jwt_header, decrypted_token):
    """Report whether the token's ``jti`` has been revoked.

    Callback invoked each time a client requests a protected endpoint; the
    result tells the JWT extension whether the token is on the blocklist.

    Returns:
        Whatever ``RevokedTokenModel.is_jti_blacklisted`` returns for the
        token's ``jti`` (expected to be True or False).
    """
    return RevokedTokenModel.is_jti_blacklisted(decrypted_token['jti'])


# Logout and token revocation
class UserLogoutAccess(Resource):
    """Resource that revokes the access token used for the current request."""

    @jwt_required()
    def post(self):
        """Record the current access token's ``jti`` as revoked.

        Returns:
            A confirmation message on success, or an error message with HTTP
            status 500 when storing the revoked token fails.
        """
        jti = get_jwt()['jti']
        try:
            revoked_token = RevokedTokenModel(jti=jti)
            revoked_token.add()
        except Exception:
            # Catch Exception rather than using a bare except, so SystemExit
            # and KeyboardInterrupt still propagate, and log the traceback so
            # the failure is not swallowed silently.
            logging.getLogger(__name__).exception("Failed to revoke access token")
            return {'message': 'Something went wrong'}, 500
        return {'message': 'Access token has been revoked'}


class UserLogoutRefresh(Resource):
    """Resource that revokes the refresh token used for the current request."""

    @jwt_required(refresh=True)
    def post(self):
        """Record the current refresh token's ``jti`` as revoked.

        Returns:
            A confirmation message on success, or an error message with HTTP
            status 500 when storing the revoked token fails.
        """
        jti = get_jwt()['jti']
        try:
            revoked_token = RevokedTokenModel(jti=jti)
            revoked_token.add()
        except Exception:
            # Catch Exception rather than using a bare except, so SystemExit
            # and KeyboardInterrupt still propagate, and log the traceback so
            # the failure is not swallowed silently.
            logging.getLogger(__name__).exception("Failed to revoke refresh token")
            return {'message': 'Something went wrong'}, 500
        return {'message': 'Refresh token has been revoked'}
